{
  "cells": [
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "eB-dvsMI09O4"
      },
      "source": [
        "##### Copyright 2024 The TensorFlow Authors."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "cellView": "form",
        "id": "mwvC53CC1K3n"
      },
      "outputs": [],
      "source": [
        "# @title Licensed under the Apache License, Version 2.0 (the \"License\");\n",
        "# you may not use this file except in compliance with the License.\n",
        "# You may obtain a copy of the License at\n",
        "#\n",
        "# https://www.apache.org/licenses/LICENSE-2.0\n",
        "#\n",
        "# Unless required by applicable law or agreed to in writing, software\n",
        "# distributed under the License is distributed on an \"AS IS\" BASIS,\n",
        "# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n",
        "# See the License for the specific language governing permissions and\n",
        "# limitations under the License."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "w61LHyw7v_yx"
      },
      "source": [
        "Converting Keras to TFLite (via the JAX backend)\n",
        "=========="
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "zQXWQ7y11eIR"
      },
      "source": [
        "\u003ctable class=\"tfo-notebook-buttons\" align=\"left\"\u003e\n",
        "  \u003ctd\u003e\n",
        "    \u003ca target=\"_blank\" href=\"https://www.tensorflow.org/lite/examples/keras/keras_jax_backend_to_tfl\"\u003e\u003cimg src=\"https://www.tensorflow.org/images/tf_logo_32px.png\" /\u003eView on TensorFlow.org\u003c/a\u003e\n",
        "  \u003c/td\u003e\n",
        "  \u003ctd\u003e\n",
        "    \u003ca target=\"_blank\" href=\"https://colab.research.google.com/github/tensorflow/tensorflow/blob/master/tensorflow/lite/g3doc/examples/keras/keras_jax_backend_to_tfl.ipynb\"\u003e\u003cimg src=\"https://www.tensorflow.org/images/colab_logo_32px.png\" /\u003eRun in Google Colab\u003c/a\u003e\n",
        "  \u003c/td\u003e\n",
        "  \u003ctd\u003e\n",
        "    \u003ca target=\"_blank\" href=\"https://github.com/tensorflow/tensorflow/blob/master/tensorflow/lite/g3doc/examples/keras/keras_jax_backend_to_tfl.ipynb\"\u003e\u003cimg src=\"https://www.tensorflow.org/images/GitHub-Mark-32px.png\" /\u003eView source on GitHub\u003c/a\u003e\n",
        "  \u003c/td\u003e\n",
        "  \u003ctd\u003e\n",
        "    \u003ca href=\"https://storage.googleapis.com/tensorflow_docs/tensorflow/tensorflow/lite/g3doc/examples/keras/keras_jax_backend_to_tfl.ipynb\"\u003e\u003cimg src=\"https://www.tensorflow.org/images/download_logo_32px.png\" /\u003eDownload notebook\u003c/a\u003e\n",
        "  \u003c/td\u003e\n",
        "\u003c/table\u003e"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "2-t7bCE0lGsH"
      },
      "outputs": [],
      "source": [
        "import os\n",
        "\n",
        "# Select the JAX backend. Keras picks its backend when it is first imported,\n",
        "# so this has to run before the `import keras` cell below.\n",
        "os.environ.update(KERAS_BACKEND=\"jax\")"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "SNUGBsILwSSs"
      },
      "source": [
        "## Setup"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "Ht0UjgDxliW9"
      },
      "outputs": [],
      "source": [
        "import keras\n",
        "import tensorflow as tf\n",
        "import numpy as np"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "kZhJOer0wXWP"
      },
      "source": [
        "## Get the test image data"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "FirUqiycez0X"
      },
      "outputs": [],
      "source": [
        "import io\n",
        "\n",
        "import requests\n",
        "from PIL import Image\n",
        "\n",
        "url = \"https://storage.googleapis.com/download.tensorflow.org/example_images/astrid_l_shaped.jpg\"\n",
        "\n",
        "# Download the whole image up front and fail loudly on an HTTP error or a hung\n",
        "# connection, rather than letting PIL choke on a partial or non-image body.\n",
        "response = requests.get(url, timeout=30)\n",
        "response.raise_for_status()\n",
        "\n",
        "# Force 3-channel RGB and the 224x224 size the model is fed below.\n",
        "image = Image.open(io.BytesIO(response.content)).convert(\"RGB\")\n",
        "image = image.resize((224, 224))\n",
        "\n",
        "# Add a leading batch axis: (224, 224, 3) becomes (1, 224, 224, 3).\n",
        "input_image = np.expand_dims(np.array(image), axis=0)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "CHMbP6ZWwcV_"
      },
      "source": [
        "## Instantiate a ResNet50 model from the Keras models library"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "CbJyUj1IoqF6"
      },
      "outputs": [],
      "source": [
        "# ResNet50 with its ImageNet classification head and pretrained ImageNet weights.\n",
        "# With KERAS_BACKEND set to \"jax\" above, the model runs on the JAX backend.\n",
        "jax_model = keras.applications.resnet.ResNet50(\n",
        "    weights=\"imagenet\",\n",
        "    include_top=True,\n",
        ")"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "B5yqkEEXwo13"
      },
      "source": [
        "## Run the Keras JAX model with the test input"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "IBHBetUqfhDA"
      },
      "outputs": [],
      "source": [
        "# Apply the ResNet50 input preprocessing, then run a forward pass.\n",
        "input_data = keras.applications.resnet50.preprocess_input(input_image)\n",
        "jax_model_output = jax_model(input_data)\n",
        "\n",
        "# decode_predictions returns one list per batch item; each entry is a\n",
        "# (class_id, class_name, score) tuple. Keep the top-1 entry of the only image.\n",
        "top_predictions = keras.applications.resnet.decode_predictions(jax_model_output, top=1)\n",
        "decoded_preds = top_predictions[0][0]\n",
        "print(\"Predicted class:\", decoded_preds[1])"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "HaUJGHGpw0KD"
      },
      "source": [
        "## Export the Keras JAX model as a TensorFlow SavedModel"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "IZ4YqZLTrGc6"
      },
      "outputs": [],
      "source": [
        "# Export the model to a SavedModel directory, which is what the TFLite\n",
        "# converter reads in the next section.\n",
        "saved_model_dir = \"resnet50_saved_model\"\n",
        "jax_model.export(filepath=saved_model_dir)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "bNSGGXfhw4uO"
      },
      "source": [
        "## Convert to a TFLite model file"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "MdJz2eKqsEhA"
      },
      "outputs": [],
      "source": [
        "# Build a converter from the exported SavedModel directory and run the\n",
        "# conversion; the result is passed to the interpreter as `model_content` below.\n",
        "converter = tf.lite.TFLiteConverter.from_saved_model(saved_model_dir=saved_model_dir)\n",
        "tflite_model = converter.convert()"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "NZ4DjfGSWS7O"
      },
      "source": [
        "## Run using TFLite Runtime"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "CtMSYAkwWWVm"
      },
      "outputs": [],
      "source": [
        "# Load the converted model straight from memory and allocate its tensors.\n",
        "interpreter = tf.lite.Interpreter(model_content=tflite_model)\n",
        "interpreter.allocate_tensors()\n",
        "\n",
        "# Feed the same preprocessed image that was given to the Keras JAX model.\n",
        "# The array is cast to the input tensor's declared dtype before it is set.\n",
        "input_details = interpreter.get_input_details()[0]\n",
        "interpreter.set_tensor(\n",
        "    input_details[\"index\"],\n",
        "    np.asarray(input_data, dtype=input_details[\"dtype\"]),\n",
        ")\n",
        "interpreter.invoke()\n",
        "\n",
        "output_details = interpreter.get_output_details()\n",
        "output_data = interpreter.get_tensor(output_details[0][\"index\"])\n",
        "\n",
        "# decode_predictions yields (class_id, class_name, score) tuples; keep the\n",
        "# top-1 entry for the single image in the batch.\n",
        "tfl_decoded_preds = keras.applications.resnet.decode_predictions(\n",
        "    output_data, top=1\n",
        ")[0][0]\n",
        "print(\"Predicted class:\", tfl_decoded_preds[1])\n",
        "\n",
        "# Sanity check: the converted model should closely reproduce the original scores.\n",
        "max_abs_diff = np.max(np.abs(output_data - np.asarray(jax_model_output)))\n",
        "print(\"Max abs difference vs. Keras JAX output:\", max_abs_diff)"
      ]
    }
  ],
  "metadata": {
    "colab": {
      "name": "keras_jax_backend_to_tfl.ipynb",
      "provenance": [],
      "toc_visible": true
    },
    "kernelspec": {
      "display_name": "Python 3",
      "name": "python3"
    }
  },
  "nbformat": 4,
  "nbformat_minor": 0
}
